import atexit
import errno
from copy import copy
from os import remove
from sys import getrefcount

import numpy

from .filearray import FileArray
from .units import Units
from .utils import subspace_array, iterindices

# --------------------------------------------------------------------
# Set of partitions' temporary files
#
# Each element is the name of a temporary file holding a partition's
# data. Names are added by Partition.to_disk and are taken out again
# by _remove_temporary_files when the file is removed from disk.
#
# For example:
# >>> _temporary_files
# set(['/tmp/cf_array_B8SSw2.npy',
#      '/tmp/cf_array_iRekAW.npy'])
# --------------------------------------------------------------------
_temporary_files = set()

def _remove_temporary_files(filename=None):
    '''

Remove temporary partition files from disk.

The removed files' names are deleted from the _temporary_files set.

It is intended to delete individual files as part of the garbage
collection process and to delete all files when python exits.

This is quite brutal and may break partitions if used unwisely. It is
not recommended to be used as a general tidy-up function.

A file which has already disappeared from disk is not treated as an
error: its name is simply dropped from the _temporary_files set. Any
other failure to remove a file is raised, in which case the names of
the files which have not yet been processed remain in the set.

:Parameters:

    filename : str, optional
        The name of file to remove. The file name must be in the
        _temporary_files set, otherwise nothing is done. By default
        all files given in the _temporary_files set are removed.

:Returns:

    None

:Raises:

    OSError :
        If a file exists but can not be removed.

**Examples**

>>> _temporary_files
set(['/tmp/cf_array_B8SSw2.npy',
     '/tmp/cf_array_G756ks.npy',
     '/tmp/cf_array_iRekAW.npy'])
>>> _remove_temporary_files('/tmp/cf_array_G756ks.npy')
>>> _temporary_files
set(['/tmp/cf_array_B8SSw2.npy',
     '/tmp/cf_array_iRekAW.npy'])
>>> _remove_temporary_files()
>>> _temporary_files
set()

'''
    if filename is None:
        # Remove all temporary files. Iterate over a snapshot because
        # the set is modified inside the loop.
        filenames = list(_temporary_files)
    elif filename in _temporary_files:
        # Remove the given temporary file
        filenames = [filename]
    else:
        # Not one of our temporary files, so leave it alone
        return
    #--- End: if

    for name in filenames:
        # Forget the name first so that a file which can not be
        # removed does not get retried every time this is called
        _temporary_files.discard(name)
        try:
            remove(name)
        except OSError as error:
            # A file which is already missing has effectively been
            # removed, so only propagate other kinds of failure
            if error.errno != errno.ENOENT:
                raise
    #--- End: for
#--- End: def

# --------------------------------------------------------------------
# Instruction to remove all of the temporary files from all partition
# arrays at exit. The function is registered without arguments, so
# every file named in the _temporary_files set is removed.
# --------------------------------------------------------------------
atexit.register(_remove_temporary_files)


# ====================================================================
#
# Partition object
#
# ====================================================================

class Partition(object):
    '''

A partition of a master data array.

A partition stores a piece of the master array's data, either in
memory as a numpy array or on disk as a file-backed array, together
with the metadata (`location`, `order`, `direction`, `shape`, `part`
and `Units`) which the `conform` method uses to turn that piece into
a numpy array with a requested dimension order, dimension directions
and units.

'''
    __slots__ = ('data',       # The partition's data array
                 'direction',  # Direction of each dimension (dict or bool)
                 'location',   # Location of the data in the master array
                 'order',      # Identities of the data's dimensions
                 'part',       # Subspace of the data to return on access
                 'shape',      # Shape of the data as a master array subspace
                 'Units',      # Units of the data
                 '_save',      # Set by conform, acted upon by close
                 '_original',  # Pre-conformed copy to be restored by close
                 )

    def __init__(self, **kwargs):
        ''' 

**Initialization**

:Parameters:

    data : numpy array-like, optional
        The data for the partition. Must be a numpy array or any array
        storing object with a similar interface.

    direction : dict or bool, optional
        The direction of each dimension of the partition's data. It is
        a boolean if the partition's data is a scalar array, otherwise
        it is a dictionary keyed by the dimensions' identities as
        found in `order`.

    location : list, optional
        The location of the partition's data in the master array.

    order : list, optional
        The identities of the dimensions of the partition's data. If
        the partition's data a scalar array then it is an empty list.

    part : list, optional
        The subspace of the partition's data to be returned when it is
        accessed. If the partition's data is to be returned complete
        then `part` may be an empty list.

    shape : list, optional
        The shape of the partition's data as a subspace of the master
        array. If the master array is a scalar array then `shape` is
        an empty list.

    Units : Units, optional
        The units of the partition's data.

**Examples**

>>> p = Partition(data      = numpy.arange(20).reshape(2,5,1),
...               direction = {'dim0', True, 'dim1': False, 'dim2': True},
...               location  = [(0, 6), (1, 3), (4, 5)],
...               order     = ['dim1', 'dim0', 'dim2'],
...               shape     = [5, 2, 1],
...               part      = [],
...               Units     = cf.Units('K'))

>>> p = Partition(data      = numpy.arange(20).reshape(2,5,1),
...               direction = {'dim0', True, 'dim1': False, 'dim2': True},
...               location  = [(0, 6), (1, 3), (4, 5)],
...               order     = ['dim1', 'dim0', 'dim2'],
...               shape     = [5, 2, 1],
...               part      = [slice(None, None, -1), [0,1,3,4], slice(None)],
...               Units     = cf.Units('K'))

>>> p = Partition(data      = numpy.array(4),
...               direction = True,
...               location  = [(4, 5)],
...               order     = ['dim1'],
...               shape     = [1],
...               part      = [],
...               Units     = cf.Units('K'))

'''

        self._save = None
        '''
asfdasdfsa

'''
        self._original = None
        '''

sdasdasdsa

'''

        # Set attributes from keyword arguments
        for attr, value in kwargs.iteritems():
            setattr(self, attr, value)
    #--- End: def

    def __del__(self):
        '''

Called when the partition's reference count reaches zero.

If the partition contains a temporary file which is not referenced by
any other partition then the temporary file is removed from disk.

'''     
        if (hasattr(self.data, '_partition_file') and
            getrefcount(self.data) <= 2):
            # This partition contains a temporary file which is not
            # referenced by any other partition, so remove the file
            # from disk.
            _remove_temporary_files(self.data._partition_file)
    #--- End: def

    def __str__(self):
        '''
x.__str__() <==> str(x)

'''
        return '%s: %s' % \
            (self.__class__.__name__,
             dict((attr, getattr(self, attr, None)) 
                  for attr in self.__slots__))
    #--- End: def

    # ----------------------------------------------------------------
    # Property attribute: indices (can't set or delete)
    # ----------------------------------------------------------------
    @property
    def indices(self):
        '''

The indices of the master array which correspond to this partition's data.

:Returns:

    out : tuple
        A tuple of slice objects or, if the partition's data is a
        scalar array, an empty tuple.

**Examples**

>>> p.data.shape
(5, 7)
>>> p.indices
(slice(0, 5), slice(2, 9))

>>> p.data.shape
()
>>> p.indices
()

'''
        return tuple((slice(*r) for r in self.location))
    #--- End: def

    # ----------------------------------------------------------------
    # Property attribute: in_memory (can't set or delete)
    # ----------------------------------------------------------------
    @property
    def in_memory(self):
        '''

True if and only if the partition's data is in memory as opposed to on
disk.

**Examples**

>>> p.in_memory
False

'''
        return isinstance(self.data, numpy.ndarray)
    #--- End: if

    # ----------------------------------------------------------------
    # Property attribute: on_disk (can't set or delete)
    # ----------------------------------------------------------------
    @property
    def on_disk(self):
        '''

Return True if the partition's data is on disk as opposed to in memory.

**Examples**

>>> p.on_disk
True

'''
        return not self.in_memory
    #--- End: if

    # ----------------------------------------------------------------
    # Property attribute: isscalar (can't set or delete)
    # ----------------------------------------------------------------
    @property
    def isscalar(self):
        '''

True if and only if the partition's data is a scalar array.

**Examples**

>>> p.data.ndim
0
>>> p.isscalar
True

>>> p.data.ndim
1
>>> p.isscalar
False

'''
        return not self.data.ndim
    #--- End: def

    def close(self, save=False):
        '''

Close the partition.

Closing the partition is important for memory management. The
partition should always be closed after it is conformed to prevent
memory leaks.

Closing the partition does one of three things, depending on the
values of the partition's `_original` and `_save` attributes and on
the `save` parameter:
* Nothing.
* Stores the data in a temporary file. 
* Reverts the partition to a previous state.

:Parameters:

    save : bool, optional
        If True and the partition is not to be reverted to a previous
        state then force its data to be stored in a temporary file.

:Returns:

    None

**Examples**

>>> p.close()

'''
        if self._original:
            # The whole partition is to replaced with its
            # pre-conformed state
            self.update_from(self._original)
            self._original = None

        elif self._save or save:
            # The partition's data array is to be saved to a temporary
            # file
            self.to_disk()
            self._save = None
    #--- End: def

    def conform(self, order=None, direction=None, units=None, save=False,
                revert_to_file=False, hardmask=True, dtype=None):
        '''
    
After a partition has been conformed, the partition must be closed
(with the `close` method) before another partition is conformed,
otherwise a memory leak could occur. For example:

>>> order          = partition_array.order
>>> direction      = partition_array.direction
>>> units          = partition_array.units
>>> save           = partition_array.save_to_disk()
>>> revert_to_file = True
>>> for partition in partition_array.flat():
...
...    # Conform the partition
...    partition.conform(order, direction, units, save, revert_to_file)
...
...    # [ Some code to operate on the conformed partition ]
...
...    # Close the partition
...    partition.close()
...
...    # Now move on to conform the next partition 
...
>>>  

:Parameters:

    order : list

    direction : dict

    units : Units

    save : bool, optional
        * If False then the conformed partition's data is to be kept
          in memory when the partition's `close` method is called.

        * If True and `revert_to_file` is False then the conformed
          partition's data will be to be saved to a temporary file on
          disk when the partition's `close` method is called.

        * If True and `revert_to_file` is True and the pre-conformed
          partition's data was in memory then the conformed
          partition's data will be saved to a temporary file on disk
          when the partition's `close` method is called.

        * If True and `revert_to_file` is True and the pre-conformed
          partition's data was on disk then the file pointer will be
          reinstated when the partition's `close` method is called.

    revert_to_file : bool, optional
        * If False and `save` is True then the conformed partition's
          data will be saved to a temporary file on disk when the
          partition's `close` method is called.

        * If True and `save` is True and the pre-conformed partition's
          data was on disk then the file pointer will be reinstated
          when the partition's `close` method is called.

        * Otherwise does nothing.

    dtype : numpy.dtype, optional
        Convert the data array to this data type. By default no
        conversion occurs.

    hardmask : bool, optional
        If False then force the data array's mask to be soft. By
        default the mask is forced to be hard.
    
:Returns: 

    out : numpy array
        The partition's conformed data as a numpy array. The same
        numpy array is stored as an object identity by the partition's
        `data` attribute.

'''
        self._save = save

        p_order     = self.order
        p_direction = self.direction
        p_part      = self.part

        if self.on_disk:
            # --------------------------------------------------------
            # The data is in a file on disk
            # --------------------------------------------------------
            if revert_to_file and save:
                self._original  = self.copy()
                self._save      = None
                _partition_file = None
            else:
                _partition_file = getattr(self.data, '_partition_file', None)
                
            if not p_part:
                indices = Ellipsis
            else:
                indices = tuple(p_part)
                self.part = []
            #--- End: if

            # Read data from a file into a numpy array
            p_data = self.data[indices]
            
            if _partition_file and getrefcount(self.data) <= 2:
                # This partition contains a temporary file which is
                # not referenced by any other partition, so we can
                # remove the file from disk.
                _remove_temporary_files(self.data._partition_file)
            #--- End: if

            # No other object points to this numpy array (because we
            # just created it from a file on disk), so we can change
            # it in place.
            inplace = True

        else:
            # --------------------------------------------------------
            # The data is in a numpy array in memory
            # --------------------------------------------------------
            if getrefcount(self.data) <= 2:
                # No other object points to this numpy array, so we
                # might be able to change it in place.
                inplace = True
            else:
                # At least one other object points to this numpy
                # array, so we can not change it in place.
                inplace = False

            p_data = self.data

            del self.data

            if p_part:
                p_data = subspace_array(p_data, p_part)
                self.part = []
        #--- End: if


        # ------------------------------------------------------------
        # Make sure that the data have the correct units. This process
        # will deep copy the data if required (e.g. if another
        # partition is referencing this numpy array), even if the
        # units are already correct.
        # ------------------------------------------------------------
        # Make sure that we deep copy if the data is not contiguous or
        # is of integer type
        if not p_data.flags['C_CONTIGUOUS'] or p_data.dtype.kind == 'i':
            inplace = False
#consirder a conform method which does in place?

        p_data = Units.conform(p_data, self.Units, units, inplace=inplace)

        # ------------------------------------------------------------
        # Remove excessive size 1 dimensions
        # ------------------------------------------------------------
        if p_order:
            shape        = []
            temp_p_order = []            
            for size, dim in zip(p_data.shape, p_order):
                if dim in order:
                    shape.append(size)
                    temp_p_order.append(dim)
            #--- End: for
            if len(shape) != len(p_order):
                p_order = temp_p_order
                p_data  = numpy.ma.resize(p_data, shape)
        #--- End: if

        # ------------------------------------------------------------
        # Check for reversed dimensions
        # ------------------------------------------------------------
        if p_data.size > 1:
            reversed_dimensions = False
            indices             = [slice(None)] * len(p_order)
            for i, dim in enumerate(p_order):
                if p_direction[dim] != direction[dim]:
                    indices[i]          = slice(None, None, -1)
                    reversed_dimensions = True
            #--- End: for
            if reversed_dimensions:
                p_data = p_data[tuple(indices)]
        #--- End: if

        # ------------------------------------------------------------
        # Insert missing size 1 dimensions
        # ------------------------------------------------------------
        if len(p_order) < len(order):
            for i, dim in enumerate(order):
                if dim not in p_order:
                    p_data = numpy.ma.expand_dims(p_data, i)
                    p_order.insert(i, dim)
            #--- End: for
        #--- End: if

        # ------------------------------------------------------------
        # Reorder axes
        # ------------------------------------------------------------
        axes = [p_order.index(dim) for dim in order]
        if axes != range(len(axes)):
            p_data = numpy.ma.transpose(p_data, axes)

        # ------------------------------------------------------------
        # Make sure the array is a masked array and that the mask is
        # shrunk if possible
        # ------------------------------------------------------------
        if not hasattr(p_data, 'mask'):
            p_data = p_data.view(numpy.ma.MaskedArray)
        else:
            p_data.shrink_mask()

        # ------------------------------------------------------------
        # If a different data-tpye has been speicified then convert
        # the data array
        # ------------------------------------------------------------
        if dtype is not None and dtype != p_data.dtype:
            p_data = p_data.astype(dtype) # One day, astype might have a 'copy' parameter

        # ------------------------------------------------------------
        # Set the hardness of the mask
        # ------------------------------------------------------------
        if hardmask:
            p_data.harden_mask()
        else:
            p_data.soften_mask()

        # ------------------------------------------------------------
        # Update the partition
        # ------------------------------------------------------------
        self.order     = order[:]
        self.direction = copy(direction)
        self.Units     = units.copy()
        self.data      = p_data

        return p_data
    #--- End: def

    def copy(self):
        '''

Return a deep copy.

Equivalent to ``copy.deepcopy(p)``.

:Returns:

    out :
        A deep copy.

**Examples**

>>> q = p.copy()

'''
        # ------------------------------------------------------------
        # Note that the partition's data is *not* deep copied, but a
        # deep copy is produced, if necessary, by the partition's
        # `conform` method, if and when it is called. Accessing the
        # data prior to calling conform could create undesired
        # modifications to other partitions.
        # ------------------------------------------------------------

        new = type(self)()

        new.data = self.data

        new.location  = self.location[:]
        new.order     = self.order[:] 
        new.shape     = self.shape[:]
        new.part      = self.part[:]
        new.direction = copy(self.direction)
        new.Units     = self.Units.copy()

        # DCH : copy _original and _save?

        return new
    #--- End: def

    def iterarray_indices(self):
        '''

Return an iterator over indices of the conformed array.

The array is not conformed.

:Returns:

    out : generator
        An iterator over indices of the conformed array.

**Examples**

>>> p.shape
[2, 1, 3]
>>> for index in p.iterarray_indices():
...     print index
...
(0, 0, 0)
(0, 0, 1)
(0, 0, 2)
(1, 0, 0)
(1, 0, 1)
(1, 0, 2)

'''
        return iterindices([(0, n) for n in self.shape])
    #--- End: def

    def itermaster_indices(self):
        '''

Return an iterator over indices of the master array which are spanned
by the conformed array.

The array is not conformed.

:Returns:

    out : generator
        An iterator over indices of the master array which are spanned
        by the conformed array.

**Examples**

>>> p.location
[(3, 5), (0, 1), (0, 3)]
>>> for index in p.itermaster_indices():
...     print index
...
(3, 0, 0)
(3, 0, 1)
(3, 0, 2)
(4, 0, 0)
(4, 0, 1)
(4, 0, 2)

'''
        return iterindices(self.location)
    #--- End: def

    def new_part(self, indices, dim2position, data_direction):
        '''

Return the `part` attribute updated for new indices.

Does not change the partition in place.

:Parameters:

    indices : list

    dim2position : dict

    data_direction : dict

:Returns:

    out : list

**Examples**

>>> p.part = p.new_part(indices, dim2position, data_direction)

'''   
        # ------------------------------------------------------------
        # If a dimension runs in the wrong direction so change its
        # index to account for this.
        #
        # For example, if a dimension with the wrong direction has
        # size 10 and its index is slice(3,8,2) then after the
        # direction is set correctly, the index needs to changed to
        # slice(6,0,-2):
        #
        # >>> a = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
        # >>> a[slice(3, 8, 2)]          
        # [6, 4, 2]
        # >>> a.reverse()
        # >>> print a
        # >>> a = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
        # >>> a[slice(6, 0, -2)]    
        # [6, 4, 2]
        # ------------------------------------------------------------
        direction = self.direction

        if self.data.ndim > 0:
            indices = indices[:]
            for dim, i in dim2position.iteritems():
            
                if not (dim in direction and
                        direction[dim] != data_direction[dim]):
                    continue
            
                # Still here? Then this dimension runs in the wrong
                # direction

                # Reset the direction
                direction[dim] = data_direction[dim]
            
                # Modify the index to account for the changed
                # direction
                size = self.shape[self.order.index(dim)]

                if isinstance(indices[i], slice):
                    start, stop, step = indices[i].indices(size)
                    # Note that step is assumed to be always +ve here
                    div, mod = divmod(stop-start-1, step)
                    start = size - 1 - start
                    stop  = start - div*step - 1
                    if stop < 0: stop = None
                    indices[i] = slice(start, stop, -step)
                else:
                    size -= 1
                    indices[i] = [size-j for j in indices[i]]
            #--- End: for
        #--- End: if

        # Reorder the new indices
        indices = [(indices[dim2position[dim]] 
                    if dim in dim2position else
                    slice(None))
                   for dim in self.order]
      
        if not self.part:
            return indices
    
        # Still here?
        p_part = []
        for part, index, size in zip(self.part,
                                     indices, 
                                     self.data.shape):

            if isinstance(part, slice):
                if isinstance(index, slice):

                    start , stop , step  = part.indices(size)

                    size1, mod = divmod(stop-start-1, step)            

                    start1, stop1, step1 = index.indices(size1+1)

                    size2, mod = divmod(stop1-start1, step1)

                    if mod != 0:
                        size2 += 1
                
                    start += start1 * step
                    step  *= step1
                    stop   = start + (size2-1)*step
                    if step > 0:
                        stop += 1
                    else:
                        stop -= 1
                    if stop < 0:
                        stop = None
                    p_part.append(slice(start, stop, step))
                    continue
                else:
                    new_part = range(*part.indices(size))
                    new_part = [new_part[i] for i in index]
            else:
                if isinstance(index, slice):
                    new_part = part[index]
                else:
                    new_part = [part[i] for i in index]
            #--- End: if
    
            # Still here? Then the new element of p_part is a list of
            # integers, so let's see if we can convert it to a slice
            # before appending it.
            if len(new_part) == 1:
                # Convert a single element list to a slice object
                new_part = new_part[0]
                new_part = slice(new_part, new_part+1, 1)
            else:                
                step = new_part[1] - new_part[0]
                if step:
                    if step > 0:
                        start, stop = new_part[0], new_part[-1]+1
                    else:
                        start, stop = new_part[0], new_part[-1]-1
                        if new_part == range(start, stop, step):
                            if stop < 0: stop = None
                            new_part = slice(start, stop, step)
            #--- End: if

            p_part.append(new_part)
        #--- End: for
    
        return p_part
    #--- End: def

    def to_disk(self):
        '''

Store the partition's data in a temporary file on disk in place.

Assumes that the partition's data is currently in memory, but this is
not checked.

:Returns:

    None

**Examples**

>>> p.to_disk()

'''
        self.data = FileArray(self.data)  

        _temporary_files.add(self.data._partition_file)
    #--- End: if

    def update_from(self, other):
        '''

Completely update the partition with another partition's attributes in
place.

The updated partition is always independent of the other partition.

:Parameters:

    other : Partition

:Returns:

    None

**Examples**

>>> p.update_from(q)

'''
        self.data = other.data

        self.location  = other.location[:]
        self.order     = other.order[:] 
        self.shape     = other.shape[:]
        self.part      = other.part[:]
        self.direction = copy(other.direction)
        self.Units     = other.Units.copy()

        self._original = None
        self._save     = None
    #--- End: def

#--- End: class
